package com.bayesianWeka;

import weka.classifiers.AbstractClassifier;
import weka.core.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

/**
 * Target-learning KDB: a k-dependence Bayesian classifier (k = 2) for Weka that combines
 * two structures.
 *
 * <p>A "general" structure is learned once from the training data in
 * {@link #buildClassifier(Instances)}. For every instance passed to
 * {@link #distributionForInstance(Instance)} a second, "local" structure is built from the
 * instance-specific mutual information supplied by {@code CommonUtils}; the normalised class
 * distributions of both models are then averaged.
 *
 * @Author: Decade
 * @Date: 2021/3/5 15:53
 */
public class TargetLearningKDB extends AbstractClassifier {

    /** Marker stored in a parent slot that is not used. */
    private static final int NO_PARENT = -1;

    /** Number of class labels. */
    private int m_NumClasses;

    /** Number of attributes; in Weka this counts the predictor attributes plus the class attribute. */
    private int m_NumAttributes;

    /** Index of the class attribute in the training set. */
    private int m_ClassIndex;

    /** Parents of every attribute in the general KDB model; {@code m_parents[a]} holds up to two parent indices. */
    private int[][] m_parents;

    /** Shared helper providing counts, (conditional) mutual information and probability estimates. */
    private CommonUtils commonUtils;

    /**
     * Training phase: collects the statistics through {@code CommonUtils} and builds the
     * general KDB structure.
     *
     * @param instances the training data
     * @throws Exception propagated from the statistics helper
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        commonUtils = new CommonUtils(instances);
        commonUtils.CalculateCountXxxy();

        m_ClassIndex = commonUtils.getM_ClassIndex();
        m_NumAttributes = commonUtils.getM_NumAttributes();
        m_NumClasses = commonUtils.getM_NumClasses();

        // Arguments are evaluated left to right, so mutual information is still requested
        // before conditional mutual information, as in the original code.
        m_parents = buildParents(
                commonUtils.getMutualInfor(),
                commonUtils.getConditionMutualInfor());
    }

    /**
     * Classification phase: builds the local KDB structure for the given instance, evaluates
     * the joint probability under both the general and the local model, normalises each and
     * returns their average.
     *
     * @param instance the instance to classify
     * @return the averaged class distribution, one entry per class label
     * @throws Exception propagated from the statistics helper; {@code Utils.normalize} also
     *                   rejects distributions whose sum is zero or NaN
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        int[][] localParents = buildParents(
                commonUtils.getMutualInforLocal(instance),
                commonUtils.getConditionMutualInforLocal(instance));

        double[] probs = jointProbabilities(instance, m_parents);
        double[] probsLocal = jointProbabilities(instance, localParents);

        Utils.normalize(probsLocal);
        Utils.normalize(probs);

        // Ensemble: plain average of the two posterior estimates.
        for (int i = 0; i < probs.length; i++) {
            probs[i] = (probs[i] + probsLocal[i]) / 2;
        }
        return probs;
    }

    /**
     * Builds a 2-dependence structure: attributes are ranked with {@code Arrays.sort} on
     * {@code Point} objects carrying their mutual information, the first three ranked
     * attributes are wired to each other, and every later attribute picks two parents among
     * the attributes already placed according to conditional mutual information.
     *
     * <p>NOTE(review): predictors are indexed {@code 0 .. m_NumAttributes - 2}, which presumes
     * the class attribute is the last attribute - confirm against {@code CommonUtils}.
     * The ranking direction is defined by {@code Point.compareTo} (presumably descending
     * mutual information - verify).
     *
     * @param mutualInfo     mutual information of each predictor attribute with the class
     * @param condMutualInfo conditional mutual information between attribute pairs given the class
     * @return an {@code m_NumAttributes x 2} table of parent indices, {@link #NO_PARENT} where unused
     * @throws Exception declared defensively because the project helpers used here are not visible
     */
    private int[][] buildParents(double[] mutualInfo, double[][] condMutualInfo) throws Exception {
        int[][] parents = new int[Math.max(m_NumAttributes, 0)][2];
        for (int[] row : parents) {
            Arrays.fill(row, NO_PARENT);
        }

        int numPredictors = m_NumAttributes - 1;
        if (numPredictors < 2) {
            // Zero or one predictor: nothing can have a parent. The original indexed
            // points[1] unconditionally here and failed with an out-of-bounds error.
            return parents;
        }

        Point[] ranked = new Point[numPredictors];
        for (int i = 0; i < numPredictors; i++) {
            ranked[i] = new Point(i, mutualInfo[i]);
        }
        Arrays.sort(ranked);

        int first = ranked[0].getPointNum();
        int second = ranked[1].getPointNum();
        parents[second][0] = first;
        if (numPredictors == 2) {
            return parents;
        }

        int third = ranked[2].getPointNum();
        parents[third][0] = first;
        parents[third][1] = second;

        // Attributes already placed in the network, in placement order.
        ArrayList<Integer> inMap = new ArrayList<>();
        inMap.add(first);
        inMap.add(second);
        inMap.add(third);

        for (int rank = 3; rank < numPredictors; rank++) {
            int node = ranked[rank].getPointNum();

            // Pass 1: find the largest and the runner-up conditional mutual information.
            double maxCmi = -Double.MAX_VALUE;
            double secondCmi = -Double.MAX_VALUE;
            int firstParent = NO_PARENT;
            for (int candidate : inMap) {
                double cmi = condMutualInfo[node][candidate];
                if (cmi > maxCmi) {
                    secondCmi = maxCmi;
                    maxCmi = cmi;
                    firstParent = candidate;
                } else if (cmi > secondCmi) {
                    secondCmi = cmi;
                }
            }

            // Pass 2: the second parent is the LAST placed attribute whose value equals the
            // runner-up value; kept as a separate pass to preserve the original tie-breaking.
            int secondParent = NO_PARENT;
            for (int candidate : inMap) {
                if (condMutualInfo[node][candidate] == secondCmi) {
                    secondParent = candidate;
                }
            }

            parents[node][0] = firstParent;
            parents[node][1] = secondParent;
            inMap.add(node);
        }
        return parents;
    }

    /**
     * Computes the unnormalised joint probability P(y, x) for every class under the given
     * parent structure: P(y) times, for each predictor attribute, its conditional
     * probability given the class and its zero, one or two parents.
     *
     * @param instance the instance whose attribute values are looked up
     * @param parents  parent table as produced by {@link #buildParents(double[], double[][])}
     * @return one unnormalised joint probability per class label
     * @throws Exception declared defensively because the project helpers used here are not visible
     */
    private double[] jointProbabilities(Instance instance, int[][] parents) throws Exception {
        double[] joint = new double[m_NumClasses];
        for (int y = 0; y < m_NumClasses; y++) {
            joint[y] = commonUtils.P(y);    // P(y)
        }

        for (int x1 = 0; x1 < m_NumAttributes; x1++) {
            if (x1 == m_ClassIndex) {
                continue;
            }
            int parent1 = parents[x1][0];
            int parent2 = parents[x1][1];

            for (int y = 0; y < m_NumClasses; y++) {
                if (parent1 == NO_PARENT) {
                    // Root attribute: P(x1 | y). The general model previously multiplied
                    // by P(y) here, ignoring the attribute value altogether.
                    joint[y] *= commonUtils.P(instance, x1, y);
                } else if (parent2 == NO_PARENT) {
                    // One parent: P(x1 | parent1, y)
                    joint[y] *= commonUtils.P(instance, x1, parent1, y);
                } else {
                    // Two parents: P(x1 | parent1, parent2, y)
                    joint[y] *= commonUtils.P(instance, x1, parent1, parent2, y);
                }
            }
        }
        return joint;
    }

    /**
     * Returns the name shown in Weka's output, which identifies the classifier that
     * produced a result.
     *
     * @return a string naming this classifier
     */
    @Override
    public String toString() {
        return "TargetLearningKDB";
    }

    /**
     * Entry point for running the classifier from the command line. The class instantiated
     * in {@code runClassifier(...)} must be this file's main class.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        runClassifier(new TargetLearningKDB(), args);
    }
}
